"""
Phase 6: Dynamics & Robustness
==============================
1. 5-year lagged Z → trilemma indices
2. First differences (ΔZ → Δtrilemma)
3. OECD vs non-OECD subsample split
4. Pre/post GFC split (2008)
5. Excluding financial centers
6. Logit vs LPM for mi_sacrifice and is_peg
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# Silence every warning for the whole process (numerical warnings from the
# estimators included).
warnings.filterwarnings('ignore')

# PROJECT_DIR is two directory levels above this file; ROOT_DIR is its parent.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# PanelGLS lives in ROOT_DIR/multilateral/src, which is not a package on the
# default path, so that directory is prepended to sys.path before importing.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and destination for the markdown tables; the output
# directory is created at import time if it does not exist yet.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# OECD members (as of 2024, 38 countries)
# Codes are matched against the panel's `iso3` column to build the
# OECD / non-OECD subsamples.
OECD = {
    'AUS', 'AUT', 'BEL', 'CAN', 'CHL', 'COL', 'CRI', 'CZE', 'DNK', 'EST',
    'FIN', 'FRA', 'DEU', 'GRC', 'HUN', 'ISL', 'IRL', 'ISR', 'ITA', 'JPN',
    'KOR', 'LVA', 'LTU', 'LUX', 'MEX', 'NLD', 'NZL', 'NOR', 'POL', 'PRT',
    'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'TUR', 'GBR', 'USA',
}

# Economies dropped (by `iso3` code) in the "excluding financial centers"
# robustness check.
FINANCIAL_CENTERS = {'LUX', 'IRL', 'HKG', 'SGP', 'CHE', 'NLD', 'BEL'}


def stars(p):
    """Return the significance marker for p-value ``p``.

    ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.1 and
    an empty string otherwise.  NaN fails every comparison and therefore
    also yields an empty string.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format one estimate as a pair of table cells.

    Returns ``(coef_text, se_text)``: the coefficient with four decimals
    followed by its significance stars, and the standard error with four
    decimals wrapped in parentheses.
    """
    coef_text = f"{val:.4f}" + stars(p)
    se_text = f"({se:.4f})"
    return coef_text, se_text


def run_panel_gls(df, y_var, x_vars, label):
    """Fit ``PanelGLS`` of ``y_var`` on ``x_vars`` and collect the estimates.

    Rows with a missing value in the outcome, any regressor, ``iso3`` or
    ``year`` are dropped before fitting.

    Returns ``None`` (after printing the reason) when fewer than 50 complete
    rows remain or when the fit raises.  Otherwise returns a dict with
    ``model``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho`` and, for
    every regressor, ``<name>_coef``, ``<name>_se`` and ``<name>_p``.
    """
    needed = [y_var] + x_vars + ['iso3', 'year']
    sample = df[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  {label}: insufficient obs ({n_rows}), skipping")
        return None

    gls = PanelGLS()
    outcome = sample[y_var].values
    design = sample[x_vars].values
    try:
        gls.fit(outcome, design, sample['iso3'].values, sample['year'].values)
    except Exception as e:
        # Best-effort: a failed specification is reported and skipped so the
        # remaining models in the table still run.
        print(f"  {label}: GLS failed ({e}), skipping")
        return None

    result = dict(
        model=label,
        n_obs=gls.n_obs,
        n_countries=gls.n_countries,
        r_squared=gls.r_squared,
        rho=gls.rho,
    )
    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    # The first len(x_vars) entries of beta/se/pvalues are read in x_vars order.
    for idx, name in enumerate(x_vars):
        sig = stars(gls.pvalues[idx])
        coef, se, p = gls.beta[idx], gls.se[idx], gls.pvalues[idx]
        print(f"    {name:30s} {coef:8.4f} ({se:.4f}) {sig}")
        result.update({
            f'{name}_coef': coef,
            f'{name}_se': se,
            f'{name}_p': p,
        })

    return result


def run_logit(df, y_var, x_vars, label):
    """Pooled logit with inverse-Hessian SEs (manual implementation).

    The model ``P(y = 1) = 1 / (1 + exp(-(b0 + x'b)))`` is fitted by maximum
    likelihood (BFGS with an analytic gradient) on the rows of ``df`` that
    have no missing value in ``y_var``, ``x_vars`` or ``iso3``.  No fixed
    effects are included; ``iso3`` is only used to count countries.

    Args:
        df: DataFrame holding ``y_var``, every name in ``x_vars`` and ``iso3``.
        y_var: Outcome column; the likelihood treats it as 0/1.
        x_vars: List of regressor column names (an intercept is added).
        label: Model name used in log lines and as the ``model`` entry.

    Returns:
        ``None`` when fewer than 50 complete rows remain, when the outcome
        has fewer than 5 events or fewer than 5 non-events, or when the
        estimation raises.  Otherwise a dict with ``model``, ``n_obs``,
        ``n_countries``, ``r_squared`` (McFadden pseudo-R²), ``rho`` (always
        0.0) and, per regressor, ``<name>_coef`` (marginal effect at the
        regressor means), ``<name>_se`` (coefficient SE scaled by the same
        density factor) and ``<name>_p`` (p-value of the logit coefficient).
    """
    from scipy.optimize import minimize
    from scipy import stats as sp_stats

    cols = [y_var] + x_vars + ['iso3']
    sub = df[cols].dropna()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    y = sub[y_var].values.astype(float)
    X = np.column_stack([np.ones(len(sub)), sub[x_vars].values.astype(float)])
    n, k = X.shape

    # Check for sufficient variation
    if y.sum() < 5 or (1 - y).sum() < 5:
        print(f"  {label}: insufficient outcome variation (events={y.sum():.0f}), skipping")
        return None

    # Standardize X for numerical stability (except intercept).  Constant
    # columns get a unit scale so the division below stays finite.
    x_means = X[:, 1:].mean(axis=0)
    x_stds = X[:, 1:].std(axis=0)
    x_stds[x_stds == 0] = 1
    X_std = X.copy()
    X_std[:, 1:] = (X[:, 1:] - x_means) / x_stds

    def neg_log_likelihood(beta):
        # The linear index and the probabilities are clipped so that exp()
        # and log() never overflow or hit log(0).
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        p = np.clip(p, 1e-12, 1 - 1e-12)
        return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))

    def gradient(beta):
        z = X_std @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        return -X_std.T @ (y - p)

    # Initial values
    beta0 = np.zeros(k)

    try:
        result_opt = minimize(neg_log_likelihood, beta0, jac=gradient,
                              method='BFGS', options={'maxiter': 1000, 'gtol': 1e-6})
        if not result_opt.success:
            print(f"  {label}: logit optimization did not converge, using result anyway")

        beta_std = result_opt.x

        # Transform back to original scale: slopes are divided by the column
        # scale and the intercept absorbs the centering terms.
        beta = np.zeros(k)
        beta[1:] = beta_std[1:] / x_stds
        beta[0] = beta_std[0] - np.sum(beta_std[1:] * x_means / x_stds)

        # Hessian of the negative log-likelihood on the original scale,
        # X' diag(p(1-p)) X; its inverse is the coefficient covariance.
        z = X @ beta
        z = np.clip(z, -30, 30)
        p = 1 / (1 + np.exp(-z))
        W = p * (1 - p)
        H = X.T @ (X * W[:, None])

        try:
            V = np.linalg.inv(H)
            se = np.sqrt(np.diag(V))
        except np.linalg.LinAlgError:
            # Singular Hessian: keep the point estimates, report NaN SEs.
            se = np.full(k, np.nan)

        t_stats = beta / se
        # Two-sided normal p-value via the survival function.  The previous
        # 2 * (1 - cdf(|z|)) form cancels to exactly 0 for large |z|.
        pvalues = 2 * sp_stats.norm.sf(np.abs(t_stats))

        # Pseudo-R² (McFadden)
        ll_model = -neg_log_likelihood(beta_std)
        p_bar = y.mean()
        ll_null = n * (p_bar * np.log(p_bar + 1e-12) + (1 - p_bar) * np.log(1 - p_bar + 1e-12))
        pseudo_r2 = 1 - ll_model / ll_null if ll_null != 0 else 0

        # Marginal effects at mean
        z_mean = X.mean(axis=0) @ beta
        p_mean = 1 / (1 + np.exp(-z_mean))
        mfx = beta[1:] * p_mean * (1 - p_mean)

    except Exception as e:
        # Best-effort: report the failure and let the caller skip this model.
        print(f"  {label}: logit failed ({e}), skipping")
        return None

    res = {
        'model': label,
        'n_obs': n,
        'n_countries': sub['iso3'].nunique(),
        'r_squared': pseudo_r2,
        'rho': 0.0,
    }

    print(f"\n  {label} (N={n}, Pseudo-R²={pseudo_r2:.4f}) [Logit]")
    # Index 0 of beta/se/pvalues is the intercept, hence the i + 1 offsets.
    for i, name in enumerate(x_vars):
        sig = stars(pvalues[i + 1])
        print(f"    {name:30s} B={beta[i+1]:8.4f} (se={se[i+1]:.4f}) {sig}  "
              f"[MFX={mfx[i]:.5f}]")
        res[f'{name}_coef'] = mfx[i]  # Report marginal effects
        # Approximation: scales the coefficient SE by the density factor and
        # ignores the sampling variability of p_mean itself.
        res[f'{name}_se'] = se[i + 1] * p_mean * (1 - p_mean)
        res[f'{name}_p'] = pvalues[i + 1]

    return res


def write_table(results, filename, title, note=None):
    """Write regression results as a markdown table under ``OUT_TABLES``.

    Args:
        results: List of result dicts, one per model column.  Each needs
            ``model``, ``n_obs``, ``r_squared`` and ``n_countries``; every
            ``<var>_coef`` key must come with ``<var>_se`` and ``<var>_p``.
        filename: Name of the file created inside ``OUT_TABLES``.
        title: Text of the level-1 heading placed above the table.
        note: Optional footnote text.  When omitted, a default Panel-GLS
            note and the significance legend are appended.

    Nothing is written when ``results`` is empty.  Rows follow the order in
    which variables first appear across models; a model lacking a variable
    gets empty cells for it.
    """
    if not results:
        return

    lines = [f"# {title}\n"]

    # Collect variable names in first-seen order.  Only the trailing
    # "_coef" is stripped, so a name that itself contains "_coef" survives.
    suffix = '_coef'
    all_vars = []
    for r in results:
        for key in r:
            if key.endswith(suffix):
                vname = key[:-len(suffix)]
                if vname not in all_vars:
                    all_vars.append(vname)

    model_labels = [r['model'] for r in results]
    header = "| Variable | " + " | ".join(model_labels) + " |"
    sep = "|:---|" + "|".join(["---:" for _ in results]) + "|"
    lines.append(header)
    lines.append(sep)

    # Two rows per variable: starred coefficients, then standard errors.
    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # A second separator row sets the summary statistics apart.
    lines.append(sep)
    n_row = "| N |"
    r2_row = "| R² |"
    nc_row = "| Countries |"
    for r in results:
        n_row += f" {r['n_obs']} |"
        r2_row += f" {r['r_squared']:.4f} |"
        nc_row += f" {r['n_countries']} |"
    lines.append(n_row)
    lines.append(r2_row)
    lines.append(nc_row)

    if note:
        lines.append(f"\n{note}")
    else:
        lines.append("\n*Panel GLS with country and year fixed effects. "
                     "Standard errors in parentheses.*")
        lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # The tables contain non-ASCII characters ("²", "→"), so the encoding is
    # fixed instead of depending on the platform default.
    path.write_text('\n'.join(lines), encoding='utf-8')
    print(f"\n  Saved: {path}")


# ── 1. Lagged Demographics (5-year lag) ──────────────────────────────

def lagged_trilemma(df):
    """5-year lagged Z → trilemma indices vs. contemporary Z.

    For each trilemma index (MI, ERS, FO) two Panel-GLS models are run: one
    on contemporaneous ``Z_1..Z_3`` and one on their 5-year lags.  The FO
    models omit ``kaopen`` from the controls.  Results are written to
    ``lagged_trilemma.md``.

    Args:
        df: Panel with ``iso3``, ``year``, ``Z_1..Z_3``, the three index
            columns and the control columns.

    Returns:
        A copy of ``df`` sorted by country and year, with ``Z_1_lag5``,
        ``Z_2_lag5`` and ``Z_3_lag5`` added.  The input frame is not changed.
    """
    print("\n" + "=" * 60)
    print("1. LAGGED DEMOGRAPHICS (5-YEAR LAG)")
    print("=" * 60)

    df = df.sort_values(['iso3', 'year'])

    # groupby.shift(5) moves values five ROWS, which is a five-YEAR lag only
    # when a country's rows are consecutive years.  Where the shifted row is
    # not exactly five years earlier (gaps in the panel), the lag is set to
    # missing instead of silently using the wrong year.
    year_5_rows_back = df.groupby('iso3')['year'].shift(5)
    is_true_5y_lag = (df['year'] - year_5_rows_back) == 5
    for var in ['Z_1', 'Z_2', 'Z_3']:
        shifted = df.groupby('iso3')[var].shift(5)
        df[f'{var}_lag5'] = shifted.where(is_true_5y_lag)

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
    controls_no_kaopen = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw']

    trilemma_dvs = [
        ('mi_index', 'MI', controls),
        ('ers_index', 'ERS', controls),
        ('fo_index', 'FO', controls_no_kaopen),
    ]

    results = []

    for dv, dv_label, ctrl in trilemma_dvs:
        # Contemporary Z
        r = run_panel_gls(df, dv, ['Z_1', 'Z_2', 'Z_3'] + ctrl,
                          f'{dv_label} (Z_t)')
        if r:
            results.append(r)

        # 5-year lagged Z
        r = run_panel_gls(df, dv, ['Z_1_lag5', 'Z_2_lag5', 'Z_3_lag5'] + ctrl,
                          f'{dv_label} (Z_t-5)')
        if r:
            results.append(r)

    write_table(results, "lagged_trilemma.md",
                "Contemporary vs. 5-Year Lagged Demographics → Trilemma Indices")

    return df


# ── 2. First Differences ─────────────────────────────────────────────

def first_differences(df):
    """First-differenced Z → first-differenced trilemma indices.

    Adds a ``d_<var>`` column (within-country row-to-row change) for each of
    the Z variables, trilemma indices and controls present in the panel,
    regresses each differenced index on the differenced regressors and
    writes ``first_diff_trilemma.md``.  The dFO model omits ``d_kaopen``.

    Returns the sorted copy of ``df`` carrying the ``d_`` columns.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("2. FIRST DIFFERENCES")
    print(rule)

    df = df.sort_values(['iso3', 'year'])

    base_controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw']
    level_vars = (['Z_1', 'Z_2', 'Z_3', 'mi_index', 'ers_index', 'fo_index']
                  + base_controls + ['kaopen'])
    # NOTE(review): diff() is taken between consecutive rows of a country;
    # that is a one-year change only if the panel has no gaps - confirm.
    for level in level_vars:
        if level not in df.columns:
            continue
        df[f'd_{level}'] = df.groupby('iso3')[level].diff()

    d_controls_no_kaopen = [f'd_{name}' for name in base_controls]
    d_controls = d_controls_no_kaopen + ['d_kaopen']
    d_z = ['d_Z_1', 'd_Z_2', 'd_Z_3']

    specs = (
        ('d_mi_index', 'dMI', d_controls),
        ('d_ers_index', 'dERS', d_controls),
        ('d_fo_index', 'dFO', d_controls_no_kaopen),
    )

    results = []
    for outcome, tag, ctrl in specs:
        fit = run_panel_gls(df, outcome, d_z + ctrl, f'{tag}')
        if fit:
            results.append(fit)

    footnote = ("*Panel GLS on first-differenced variables. "
                "Standard errors in parentheses.*\n"
                "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")
    write_table(results, "first_diff_trilemma.md",
                "First-Differenced Demographics → Trilemma Indices",
                note=footnote)

    return df


# ── 3. OECD vs Non-OECD Subsample ────────────────────────────────────

def oecd_split(df):
    """OECD vs non-OECD for main trilemma regressions.

    Splits the panel by membership of ``iso3`` in ``OECD``, runs the MI, ERS
    and FO Panel-GLS models on each subsample (FO without ``kaopen``) and
    writes ``oecd_trilemma_split.md``.

    The split uses a local boolean mask, so the caller's DataFrame is left
    untouched (no helper column is added to it).
    """
    print("\n" + "=" * 60)
    print("3. OECD vs NON-OECD SUBSAMPLE")
    print("=" * 60)

    is_oecd = df['iso3'].isin(OECD)
    oecd_df = df[is_oecd].copy()
    non_oecd_df = df[~is_oecd].copy()

    print(f"  OECD: {oecd_df['iso3'].nunique()} countries, {len(oecd_df)} obs")
    print(f"  Non-OECD: {non_oecd_df['iso3'].nunique()} countries, {len(non_oecd_df)} obs")

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
    controls_no_kaopen = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw']

    trilemma_dvs = [
        ('mi_index', 'MI', controls),
        ('ers_index', 'ERS', controls),
        ('fo_index', 'FO', controls_no_kaopen),
    ]

    results = []
    for dv, dv_label, ctrl in trilemma_dvs:
        # Columns are ordered index by index, OECD first within each pair.
        for sub_df, sub_label in [(oecd_df, 'OECD'), (non_oecd_df, 'Non-OECD')]:
            r = run_panel_gls(sub_df, dv, z_vars + ctrl,
                              f'{dv_label} ({sub_label})')
            if r:
                results.append(r)

    write_table(results, "oecd_trilemma_split.md",
                "Trilemma Index Regressions: OECD vs. Non-OECD Subsample")


# ── 4. Pre/Post GFC Split ────────────────────────────────────────────

def gfc_split(df):
    """Pre-2008 vs post-2008 subsample for trilemma regressions.

    Rows with ``year < 2008`` form the pre-GFC sample and rows with
    ``year >= 2008`` the post-GFC sample.  The MI, ERS and FO models (FO
    without ``kaopen``) are run on each and written to ``gfc_split.md``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("4. PRE/POST GFC SPLIT (2008)")
    print(rule)

    # Both conditions are spelled out so rows with a missing year fall into
    # neither subsample.
    samples = [
        ('Pre-GFC', df[df['year'] < 2008].copy()),
        ('Post-GFC', df[df['year'] >= 2008].copy()),
    ]
    for tag, frame in samples:
        years = frame['year']
        print(f"  {tag}: {frame['iso3'].nunique()} countries, {len(frame)} obs, "
              f"years {years.min()}-{years.max()}")

    controls_no_kaopen = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw']
    controls = controls_no_kaopen + ['kaopen']
    z_vars = ['Z_1', 'Z_2', 'Z_3']

    specs = (
        ('mi_index', 'MI', controls),
        ('ers_index', 'ERS', controls),
        ('fo_index', 'FO', controls_no_kaopen),
    )

    results = []
    for outcome, outcome_tag, ctrl in specs:
        for tag, frame in samples:
            fit = run_panel_gls(frame, outcome, z_vars + ctrl,
                                f'{outcome_tag} ({tag})')
            if fit:
                results.append(fit)

    footnote = ("*Panel GLS with country and year fixed effects. "
                "Pre-GFC: year < 2008; Post-GFC: year >= 2008. "
                "Standard errors in parentheses.*\n"
                "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")
    write_table(results, "gfc_split.md",
                "Trilemma Index Regressions: Pre-GFC vs. Post-GFC",
                note=footnote)


# ── 5. Excluding Financial Centers ────────────────────────────────────

def excl_financial_centers(df):
    """Re-run main models excluding financial centers.

    Each trilemma index is estimated twice, on the full panel and on the
    panel without the ``FINANCIAL_CENTERS`` economies, and the paired
    columns are written to ``excl_fin_centers.md``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("5. EXCLUDING FINANCIAL CENTERS")
    print(rule)
    print(f"  Excluding: {sorted(FINANCIAL_CENTERS)}")

    is_center = df['iso3'].isin(FINANCIAL_CENTERS)
    df_excl = df[~is_center].copy()
    print(f"  Remaining: {df_excl['iso3'].nunique()} countries, {len(df_excl)} obs")

    controls_no_kaopen = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw']
    controls = controls_no_kaopen + ['kaopen']
    z_vars = ['Z_1', 'Z_2', 'Z_3']

    specs = (
        ('mi_index', 'MI', controls),
        ('ers_index', 'ERS', controls),
        ('fo_index', 'FO', controls_no_kaopen),
    )
    # Full sample first (for comparison), then the restricted sample.
    samples = ((df, 'Full'), (df_excl, 'Excl FC'))

    results = []
    for outcome, outcome_tag, ctrl in specs:
        for frame, sample_tag in samples:
            fit = run_panel_gls(frame, outcome, z_vars + ctrl,
                                f'{outcome_tag} ({sample_tag})')
            if fit:
                results.append(fit)

    footnote = ("*Panel GLS with country and year fixed effects. "
                "FC = financial centers (LUX, IRL, HKG, SGP, CHE, NLD, BEL). "
                "Standard errors in parentheses.*\n"
                "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")
    write_table(results, "excl_fin_centers.md",
                "Trilemma Indices: Excluding Financial Centers (LUX, IRL, HKG, SGP, CHE, NLD, BEL)",
                note=footnote)


# ── 6. Logit vs LPM for Binary Outcomes ──────────────────────────────

def logit_vs_lpm(df):
    """Compare logit and LPM for mi_sacrifice and is_peg.

    For each binary outcome present in the panel, the linear probability
    model (Panel GLS) is estimated first and the pooled logit second, on the
    same regressors; both columns go into ``logit_vs_lpm.md``.  Outcomes
    missing from the panel are reported and skipped.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("6. LOGIT vs. LPM COMPARISON")
    print(rule)

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
    x_vars = ['Z_1', 'Z_2', 'Z_3'] + controls

    outcomes = (('mi_sacrifice', 'MI Sacrifice'), ('is_peg', 'Peg Choice'))
    estimators = ((run_panel_gls, 'LPM'), (run_logit, 'Logit'))

    results = []
    for dep_var, label in outcomes:
        if dep_var not in df.columns:
            print(f"  {dep_var} not in panel, skipping")
            continue

        outcome = df[dep_var]
        n_events = outcome.sum()
        print(f"\n  {dep_var}: {n_events:.0f} events out of {outcome.notna().sum()} obs")

        for estimator, tag in estimators:
            fit = estimator(df, dep_var, x_vars, f'{label} ({tag})')
            if fit:
                results.append(fit)

    footnote = ("*LPM: Panel GLS with country and year fixed effects. "
                "Logit: pooled logit, marginal effects at means reported. "
                "Standard errors in parentheses.*\n"
                "*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")
    write_table(results, "logit_vs_lpm.md",
                "Logit vs. LPM Comparison: MI Sacrifice and Peg Choice",
                note=footnote)


# ── Main ─────────────────────────────────────────────────────────────

def main():
    """Load the trilemma panel and run the six robustness steps in order.

    The first two steps return an augmented frame (lag and first-difference
    columns) that is passed on; the remaining steps only read the panel and
    write their tables.
    """
    banner = "=" * 70
    print(banner)
    print("PHASE 6: DYNAMICS & ROBUSTNESS")
    print(banner)

    df = pd.read_csv(DATA / "trilemma_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")
    print(f"Years: {df['year'].min()}-{df['year'].max()}")

    # Steps 1-2: lagged demographics, then first differences.
    for augment in (lagged_trilemma, first_differences):
        df = augment(df)

    # Steps 3-6: OECD split, GFC split, financial-center exclusion, logit vs LPM.
    for analysis in (oecd_split, gfc_split, excl_financial_centers, logit_vs_lpm):
        analysis(df)

    print("\n" + banner)
    print("PHASE 6 COMPLETE")
    print(banner)


# Run the full robustness pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
